// Copyright (c) 2021 PlanetScale Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package generator

import (
	"fmt"
	"runtime/debug"
	"strings"

	"google.golang.org/protobuf/compiler/protogen"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/runtime/protoimpl"
	"google.golang.org/protobuf/types/pluginpb"
)

// ObjectSet is a set of fully-qualified Go identifiers (import path + name).
// Its String and Set methods match the method set of flag.Value, so it can be
// populated from repeated command-line values.
type ObjectSet map[protogen.GoIdent]bool

// String renders the set using Go-syntax (%#v) formatting.
func (o ObjectSet) String() string {
	rendered := fmt.Sprintf("%#v", o)
	return rendered
}

// Set parses s as a fully-qualified Go identifier of the form
// "import/path.Name" and adds it to the set.
//
// The string is split at its last '.', so the import path itself may contain
// dots (e.g. "github.com/org/pkg.Type"). Both the import path and the name
// must be non-empty; otherwise an "invalid object name" error is returned
// and the set is left unchanged. The receiver must be a non-nil map.
func (o ObjectSet) Set(s string) error {
	idx := strings.LastIndexByte(s, '.')
	// Reject a missing separator, an empty import path (".Name"), and an
	// empty name ("path."): none of these can identify a real Go type.
	if idx <= 0 || idx == len(s)-1 {
		return fmt.Errorf("invalid object name: %q", s)
	}

	ident := protogen.GoIdent{
		GoImportPath: protogen.GoImportPath(s[:idx]),
		GoName:       s[idx+1:],
	}

	o[ident] = true
	return nil
}

// Config holds the user-selectable options for the vtproto generator.
type Config struct {
	// Poolable is the set of Go types (by fully-qualified identifier) marked
	// as poolable. NOTE(review): consumed by feature generators not visible
	// here — confirm the exact pooling semantics there.
	Poolable ObjectSet
	// Wrap, when true, creates each output file with an empty Go import path
	// instead of the source file's own import path. NOTE(review): presumably
	// also drives GeneratedFile.Wrapper(), which emits wrapper type
	// declarations — confirm.
	Wrap bool
	// AllowEmpty, when true, keeps an output file even if no feature
	// generated any code for it; otherwise such files are skipped.
	AllowEmpty bool
}

// Generator drives vtproto code generation for a single protoc plugin run.
// Construct it with NewGenerator.
//
// Fields:
//   - plugin: the protogen plugin whose input files are generated.
//   - cfg: the user-selected generator options.
//   - features: feature constructors resolved from the requested feature
//     names; each is invoked once per generated file.
//   - local: proto package names of every file protoc asked to generate,
//     handed to each GeneratedFile as LocalPackages.
type Generator struct {
	plugin   *protogen.Plugin
	cfg      *Config
	features []Feature
	local    map[protoreflect.FullName]bool
}

// SupportedFeatures is the bitmask of optional protoc features advertised by
// this plugin. Currently only proto3 `optional` fields are supported.
const SupportedFeatures uint64 = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

// NewGenerator prepares a Generator for plugin.
//
// It advertises SupportedFeatures on plugin and resolves featureNames into
// feature constructors via findFeatures, returning that error unchanged if
// resolution fails. It also records the proto package of every file that
// protoc asked to generate. That set is later passed to each GeneratedFile
// as LocalPackages.
//
// A nil cfg is treated as an empty Config (all options disabled).
func NewGenerator(plugin *protogen.Plugin, featureNames []string, cfg *Config) (*Generator, error) {
	if cfg == nil {
		// Generate and generateFile dereference cfg unconditionally; fall back
		// to zero-valued options instead of panicking later.
		cfg = &Config{}
	}

	plugin.SupportedFeatures = SupportedFeatures

	features, err := findFeatures(featureNames)
	if err != nil {
		return nil, err
	}

	local := make(map[protoreflect.FullName]bool, len(plugin.Files))
	for _, f := range plugin.Files {
		if f.Generate {
			local[f.Desc.Package()] = true
		}
	}

	return &Generator{
		plugin:   plugin,
		cfg:      cfg,
		features: features,
		local:    local,
	}, nil
}

// Generate creates a "<prefix>_vtproto.pb.go" output file for every input
// file that protoc asked to generate, then fills it via generateFile.
// When cfg.Wrap is set, the output file is created with an empty Go import
// path instead of the source file's own import path.
func (gen *Generator) Generate() {
	files := gen.plugin.Files
	for i := range files {
		src := files[i]
		if !src.Generate {
			continue
		}

		importPath := src.GoImportPath
		if gen.cfg.Wrap {
			importPath = ""
		}

		out := gen.plugin.NewGeneratedFile(src.GeneratedFilenamePrefix+"_vtproto.pb.go", importPath)
		gen.generateFile(out, src)
	}
}

// generateFile writes the vtproto output for a single .proto file into gf.
//
// It emits, in order:
//   - the "Code generated" header, the generator version (when build info is
//     available), the source path and the package clause;
//   - compile-time protoimpl version guards;
//   - wrapper type declarations for each top-level message and its oneof
//     fields, when the output file reports Wrapper();
//   - whatever each configured feature generates.
//
// If no feature reports having generated anything and cfg.AllowEmpty is
// false, gf is skipped so that no file is written.
func (gen *Generator) generateFile(gf *protogen.GeneratedFile, file *protogen.File) {
	out := &GeneratedFile{
		GeneratedFile: gf,
		Config:        gen.cfg,
		LocalPackages: gen.local,
	}

	// File header.
	out.P("// Code generated by protoc-gen-go-vtproto. DO NOT EDIT.")
	buildInfo, haveBuildInfo := debug.ReadBuildInfo()
	if haveBuildInfo {
		out.P("// protoc-gen-go-vtproto version: ", buildInfo.Main.Version)
	}
	out.P("// source: ", file.Desc.Path())
	out.P()
	out.P("package ", file.GoPackageName)
	out.P()

	// Version guards: the generated code fails to compile if the protoimpl
	// runtime is outside the range this generator was built against.
	runtimePkg := protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl")
	enforce := runtimePkg.Ident("EnforceVersion")
	out.P("const (")
	out.P("// Verify that this generated code is sufficiently up-to-date.")
	out.P("_ = ", enforce, "(", protoimpl.GenVersion, " - ", runtimePkg.Ident("MinVersion"), ")")
	out.P("// Verify that runtime/protoimpl is sufficiently up-to-date.")
	out.P("_ = ", enforce, "(", runtimePkg.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")")
	out.P(")")
	out.P()

	// Wrapper mode: declare a local type for each top-level message and each
	// of its oneof field types, defined in terms of the original identifier.
	if out.Wrapper() {
		for _, msg := range file.Messages {
			msgIdent := msg.GoIdent
			out.P(`type `, msgIdent.GoName, ` `, msgIdent)
			for _, oneof := range msg.Oneofs {
				for _, f := range oneof.Fields {
					out.P(`type `, f.GoIdent.GoName, ` `, f.GoIdent)
				}
			}
		}
	}

	// Every feature must run, so the loop is not short-circuited.
	anyOutput := false
	for _, newFeature := range gen.features {
		if newFeature(out).GenerateFile(file) {
			anyOutput = true
		}
	}

	if anyOutput || gen.cfg.AllowEmpty {
		return
	}
	gf.Skip()
}
